{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST-CNN\n",
    "This notebook shows how to perform handwritten digit classification using the CNN implemented with spooNN."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Preparation\n",
    "First we import all the necessary components:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import math\n",
    "import numpy as np\n",
    "import os\n",
    "import time\n",
    "from PIL import Image\n",
    "from matplotlib import pyplot\n",
    "import cv2\n",
    "from datetime import datetime\n",
    "from pynq import Xlnk\n",
    "from pynq import Overlay\n",
    "from pynq.mmio import MMIO\n",
    "from loader import loader\n",
    "import scipy.misc\n",
    "from IPython.display import display"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we get the overlay that is in the same directory as this notebook. This is loaded onto the FPGA fabric on the PYNQ.\n",
    "Then, a handle to the spooNN IP is obtained via nn_ctrl. Using this, we are able to write to the registers of the IP to configure it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Got nn_ctrl!\n"
     ]
    }
   ],
   "source": [
    "OVERLAY_PATH = 'overlay.bit'\n",
    "overlay = Overlay(OVERLAY_PATH)\n",
    "dma = overlay.axi_dma_0\n",
    "\n",
    "xlnk = Xlnk()\n",
    "nn_ctrl = MMIO(0x43c00000, length=1024)\n",
    "print('Got nn_ctrl!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we set some parameters. MINIBATCH_SIZE is how many images we want to classify in one go. The rest of the parameters are related to image properties (28 by 28 images, 8 bits per pixel, and 448 bits per line resulting in 14 total lines to transmit one image to the FPGA)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_lines per image: 14\n"
     ]
    }
   ],
   "source": [
    "## Parameters\n",
    "MINIBATCH_SIZE = 10\n",
    "height = 28\n",
    "width = 28\n",
    "pixel_bits = 8\n",
    "pixels_per_line = 448/pixel_bits\n",
    "num_lines = int((height*width)/pixels_per_line)\n",
    "print('num_lines per image: ' + str(num_lines))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's download the data set if it is not yet there, by executing the following bash script. If you don't have internet access from your PYNQ, you might have to execute this script on your host machine and then scp mnist.t to this directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[mnist.t] Nothing to do. Data is there.\r\n"
     ]
    }
   ],
   "source": [
    "!chmod u+x get_mnist.sh\n",
    "!./get_mnist.sh"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we can load the images we want to classify to the memory and display them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAAxElEQVR4nGNgGDaAEUqHpD77sfTF\nHeyS9xQYGBg+X4XynnSdYWBgYGCBclP1r2kZOlg8lmVg+PNakuHRGWSdDAwMDAyChmdMGRh+3Lou\nlDMNlzuC/14UwiUn9vJ/MITFhCmZLfr+Ji6N1j//2eGSY2j9t5sVlxzn2R9WODXW/duGU8779wdL\nXHLCd/8twyXHfPrfbWVckmr//vnikpN/8K+YEZdk679/JrjkbD+hSSKHrQ0Pw90vyJIsyByGi87v\ncBmLAQAskTyQBhizGgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=28x28 at 0x279F2E70>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA5UlEQVR4nGNgoAlgRGKXcOqFMEw/\nvhibupV///79+/fvLTkcclf7Nvz9W4UpZ/Lr7yUFHga2c397MCV9f1+SZGBgqPrx1xGLsfJCDAwM\nDBf/YpVkYGBgYCj9/vcYFw45n+9/n9sjuEyormJjWHkQh8YN3/7O58EhJ/nq70tlXK459vdvLy45\nvx9/9+IyVPgEHo1tf/+uxaWR4cffv5JoQizIHKHfDAwMH3+z8jMIFjIw/C1HkbzEwMDAsPq5eDiE\n+wKREtb5I1T9+cew6QzDEaRkUsbKwKAdzsAw7wHDuuu4XDawAAB8elD8paDXZQAAAABJRU5ErkJg\ngg==\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=28x28 at 0x279F2ED0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAAiElEQVR4nGNgGAxA7V8ujMmEIWn4\n7ylune0fccvpfpkGZ2MYq861ErfOU/e5ccop/LuB4KAba8/wGrekLkMXTlMt357lwKnTWejGD5yS\n+v/X4DRV4sV1ZC6qzgSxE7gl5Rne4zSV4ck/J5w6bcVx62Po/XeWGZdOLi+GNX9xaWQ9toELj7n0\nAQDvyR/A5hSKUQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=28x28 at 0x279F2E90>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA8ElEQVR4nMXQsWoCQRQF0Ougjbax\n3Q1I8AOEpHK3VItAfiNIBP9ASyGEgGDtD2hppY1pEjurkKApQljBNuXlYop10d1Z2/iqO3NmHrwH\nnLsKAy3cU3hFqnk4mmMrDuNvs0e5dXcNwDPLeUpTkaTIVcW2yU6Stl+SLPPXItm/9TrkfcIuNxRX\nvTzgBvxt56wpphcAgAeKpSS+OWF0X/d4mNOYm+8wZYwx3djPRzKKdtuPCIv+hgqcdHwmua4CiK8P\nADApA3h/iV9+So3Gj7RTyobaZLRb9pPoBnsMZqV8EuE9hdi0BABQH3NUqzvp+F/1B94Of+XHOAIC\nAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=28x28 at 0x279F2E90>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA2klEQVR4nGNgGMTA6AGU4SYLoZmQ\nJN3ZoQy/bgxJFi8Y64wWN7qko+V8KEtIiwvNRt03N3mgzAN/RNEkV3w3hWn8/xdNMuTTZRiz9+9e\nVlTJlX+yoCyFF7+cUOX4H/6BMdv+wMyAuZZdegVMUpnhCppzOM9cFIKwxP7+zYYKskDp73eDt/Yx\nMDDoKMv/Z/iPppNBc9XXP3/+/Hnx/PefP5xQMUaEtKEyAwPDGoaF0XDjMEH9nz+6aHYiACMjw2Wc\nkv8RzmHCkORg+IHTSoYXb/JxS252wi1HHQAAwkk8HXcuvigAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=28x28 at 0x279F2E70>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAAnUlEQVR4nGNgGPTA9382My454Sf/\n/3Pikgz8/38ZIw459jP//3vi0mj6//9vOIcJTTKIgWEXLo0MR///NMAlZ/X//zsED81YUwaG6ThN\nXfL/vQwuOZu//x8gcVGNFWZi2I3PVFNccjJ//19G5qMYa8XEsBGnqZn/X4vg1OnO8OgjLklWFYYf\nv3FJ/jvNcAfFGhYk9t/q/2dxuoeKAACi1i+vugWMtgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=28x28 at 0x279F2E90>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA10lEQVR4nGNgGLxA7ECbAozN78uK\nIif4+tdKuNydj6rIciJ7/06Gc7r/pqBodPv7VxTG1v63lhfFwpl/E+Byz//FoGhc/P8MN4yd8W8e\nqlMX/d0EdR9n89u/UEEWuLT3rg/TGRgY7B0sGNagamQwfvL377+/f//+/ff3721lNJ1ndQ08Sl8v\nZGBYfJHh2F0GHEDp3zlRXHIMC/664pQL/ffRCKfkvH9LccoxPP+CW2PGvxe4NV74O5eBVw7GY0KX\n/hu9vxmnzn9/Z8nikLTd1yDOhttW6gAA3VFLfcoS/1AAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=28x28 at 0x279F2E70>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA1klEQVR4nGNgGLKAb+pBVlxy0Q/+\n/hXGISfz+t/fv8uEsEtO+Pvv79+/74rZsMjJf/x7Yeffv3+fS2CR9P93kIEj6fa//yfhJjPBJdn/\n9zP8mHf7//9vvzAlIxm8GRgYTBgYTnzBNDbs7wWN0GW/3/57o4UpKfTu77+/f3eq3Pg7A4uLXD78\n/zeRg6Ht331lbLLz+ngYGDjX/12IRRIKIv4+whFMDAwMTMv+1uHWavD1rxpu2eJ/azhxSore+qeH\nW6vcv6W4JRl2fcESTDDAd98Pj1bCAADs81IBY4DV8QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=28x28 at 0x279F2E90>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA6ElEQVR4nGNgoD9gxBQSlGNgeFh4\n5dZFTCnvmTf//v17/dvfv+gyyn1f/vyFAQYGFhRJmXwIfeMqAwMDsqRIwZEdvz5+5d515eT5719R\nTeQ+99ePgUGBQY4J0x1sG/+2cOHwEU/r35f8OOQYYv7el8EUhdpgxXD+CS6NDK/+fqs3xBCFBt//\nfwwM/2ackLtzlUH7OLoZ3X8R4MUKNElm01v3fsNk/9Rg2uvscQIqux6bsyr+/pxhvASHpNHfv3/3\n/Pn7dwo2Sc7lf//+/ftrPTc2SQbxLc//3m3AKsXAwMAQO1UMpxxdAACQIXbQPZUQoQAAAABJRU5E\nrkJggg==\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=28x28 at 0x279F2E70>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABBElEQVR4nN3QsS8DcRjG8UcOPUFi\nINLNYGgj2M7QpTGISSIxWRkMFomIGESCwWJ1XTtYDP4Ag60xoqNBqjW0IWlPQ/Tc9yyGa/P7/QOe\n7c3nfYb3lf5N+rrH6fHVPH7pybA56zcA6JQvBntortCE6uVpeMdrZavbCg24OXel20zp8zmaSJB7\nGFE/Gpakx5klIInLH9Q8SXKmdisBcXFAUv8fOpHChbWMvrLZt0mpfhImmkPX7SiGEICfq3TPnWP7\nufeX1LwnyT9oGp9ShNaGYyTtdWDdTNoMoJwym9eCIGcpHkM7b7HRb/AtNlKDe9eCKzEsWkwPcGYz\nVeN62oo78bbVDPkFEdlychOgtWAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=28x28 at 0x279F2E70>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Read images from file\n",
    "l = loader()\n",
    "l.load_libsvm_data('mnist.t', num_samples=MINIBATCH_SIZE, num_features=784, one_hot=0, classes=None)\n",
    "images = np.zeros((MINIBATCH_SIZE,28,28))\n",
    "for i in range(0, MINIBATCH_SIZE):\n",
    "    images[i,:,:] = (l.a[i].reshape((28,28))).astype('int')\n",
    "    \n",
    "# Display images\n",
    "for i in range(0,MINIBATCH_SIZE):\n",
    "    display(scipy.misc.toimage( images[i,:,:] ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Execution\n",
    "Now we are ready to invoke the FPGA-based CNN to perform classification on the images that are in the memory.\n",
    "\n",
    "- We first allocate the buffers that will be used by the DMA engine, to transfer images to the FPGA and get the results back.\n",
    "- We write the images to the input buffer.\n",
    "- We set a register (numReps) on the mnist-cnn IP, by writing MINIBATCH_SIZE to address 0x10. This is how many images the IP expects the DMA engine to send.\n",
    "- We start the transfer on both receive and send channels.\n",
    "- After the execution is complete, we display the classification results. Notice that they match the images we displayed earlier. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "allocated buffers\n",
      "Time per image: 0.00024547576904296873 s\n",
      "Images per second: 4073.721833721834\n",
      "7\n",
      "2\n",
      "1\n",
      "0\n",
      "4\n",
      "1\n",
      "4\n",
      "9\n",
      "5\n",
      "9\n"
     ]
    }
   ],
   "source": [
    "in_buffer = xlnk.cma_array(shape=(MINIBATCH_SIZE*num_lines, 64), dtype=np.uint8)\n",
    "out_buffer = xlnk.cma_array(shape=(MINIBATCH_SIZE, 16), dtype=np.int32)\n",
    "print('allocated buffers')\n",
    "\n",
    "for i in range(0,MINIBATCH_SIZE):\n",
    "    in_buffer[i*num_lines:(i+1)*num_lines, 0:56] = np.reshape(images[i,:,:], (num_lines, 56))\n",
    "\n",
    "start = time.time()\n",
    "nn_ctrl.write(0x0, 0) # Reset\n",
    "nn_ctrl.write(0x10, MINIBATCH_SIZE)\n",
    "nn_ctrl.write(0x0, 1) # Deassert reset\n",
    "dma.recvchannel.transfer(out_buffer)\n",
    "dma.sendchannel.transfer(in_buffer)\n",
    "end = time.time()\n",
    "\n",
    "time_per_image = (end-start)/MINIBATCH_SIZE\n",
    "print(\"Time per image: \" + str(time_per_image) + \" s\")\n",
    "print(\"Images per second: \" + str(1.0/time_per_image))\n",
    "\n",
    "time.sleep(1)\n",
    "\n",
    "for i in range(0,MINIBATCH_SIZE):\n",
    "    print(str(np.argmax(out_buffer[i,:])))"
   ]
  },
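  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The input-buffer layout used above can be checked without the FPGA. The cell below is a minimal sketch, not part of the spooNN flow: it uses random 8-bit arrays in place of MNIST images and an ordinary NumPy array in place of the CMA buffer, and all `demo_*` names are hypothetical. It packs each image into 14 lines of 64 bytes, of which only the first 56 carry pixels, and then verifies that the images can be recovered from those payload columns and that the padding columns stay zero. Both checks are expected to print `True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-ins: two random 28x28 8-bit images instead of MNIST data\n",
    "demo_batch, demo_height, demo_width = 2, 28, 28\n",
    "demo_pixels_per_line = 448 // 8                                     # 56 pixels per line\n",
    "demo_lines = (demo_height * demo_width) // demo_pixels_per_line     # 14 lines per image\n",
    "demo_images = np.random.randint(0, 256, size=(demo_batch, demo_height, demo_width), dtype=np.uint8)\n",
    "\n",
    "# Plain NumPy array standing in for the DMA input buffer: 64 bytes per line,\n",
    "# only the first 56 of which carry pixels\n",
    "demo_buffer = np.zeros((demo_batch * demo_lines, 64), dtype=np.uint8)\n",
    "for i in range(demo_batch):\n",
    "    rows = slice(i * demo_lines, (i + 1) * demo_lines)\n",
    "    demo_buffer[rows, 0:demo_pixels_per_line] = demo_images[i].reshape(demo_lines, demo_pixels_per_line)\n",
    "\n",
    "# Unpacking the payload columns should recover the original images exactly\n",
    "demo_recovered = demo_buffer[:, :demo_pixels_per_line].reshape(demo_batch, demo_height, demo_width)\n",
    "print(np.array_equal(demo_recovered, demo_images))\n",
    "\n",
    "# The 8 padding bytes at the end of every line should remain zero\n",
    "print(not demo_buffer[:, demo_pixels_per_line:].any())"
   ]
  },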
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Since we worked only on 10 images in this notebook, the maximum image processing throughput capability of the CNN is not displayed. Try increasing the MINIBATCH_SIZE to observe a much higher throughput (You might want to comment out display part to skip displaying many images). "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
